{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prompt scheduling callback\n",
    "\n",
    "The prompt scheduling callback lets you change the prompt partway through a generation, similar to [prompt editing in A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#prompt-editing).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: diffusers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (0.31.0)\n",
      "Requirement already satisfied: torch in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (2.2.1+cu121)\n",
      "Requirement already satisfied: accelerate in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (1.1.1)\n",
      "Requirement already satisfied: transformers in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (4.46.2)\n",
      "Requirement already satisfied: importlib-metadata in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (8.5.0)\n",
      "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (3.16.1)\n",
      "Requirement already satisfied: huggingface-hub>=0.23.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.26.2)\n",
      "Requirement already satisfied: numpy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (1.26.4)\n",
      "Requirement already satisfied: regex!=2019.12.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2024.11.6)\n",
      "Requirement already satisfied: requests in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (2.32.3)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (0.4.5)\n",
      "Requirement already satisfied: Pillow in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from diffusers) (11.0.0)\n",
      "Requirement already satisfied: typing-extensions>=4.8.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (4.12.2)\n",
      "Requirement already satisfied: sympy in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (1.13.3)\n",
      "Requirement already satisfied: networkx in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.4.2)\n",
      "Requirement already satisfied: jinja2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (3.1.4)\n",
      "Requirement already satisfied: fsspec in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2024.10.0)\n",
      "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (8.9.2.26)\n",
      "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.3.1)\n",
      "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.0.2.54)\n",
      "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (10.3.2.106)\n",
      "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (11.4.5.107)\n",
      "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.0.106)\n",
      "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.19.3)\n",
      "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (12.1.105)\n",
      "Requirement already satisfied: triton==2.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from torch) (2.2.0)\n",
      "Requirement already satisfied: nvidia-nvjitlink-cu12 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch) (12.6.77)\n",
      "Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (24.1)\n",
      "Requirement already satisfied: psutil in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.1.0)\n",
      "Requirement already satisfied: pyyaml in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from accelerate) (6.0.2)\n",
      "Requirement already satisfied: tokenizers<0.21,>=0.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (0.20.3)\n",
      "Requirement already satisfied: tqdm>=4.27 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from transformers) (4.66.6)\n",
      "Requirement already satisfied: zipp>=3.20 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from importlib-metadata->diffusers) (3.21.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from jinja2->torch) (3.0.2)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.4.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (3.10)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2.2.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests->diffusers) (2024.8.30)\n",
      "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install -q diffusers torch accelerate transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7b276fdac7b14d1cb493944859cf0334",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Fetching 15 files:   0%|          | 0/15 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8b3c586a8d5746bcb28272273b69b8cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/1.72G [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d4971b6b226341f2a7d9d20fb3bb9f21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.fp16.safetensors:   0%|          | 0.00/246M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "83a7bf67702d4398b7a3b6ec11622849",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.fp16.safetensors:   0%|          | 0.00/608M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0ce891240d334a2a84233a1f2221757d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "diffusion_pytorch_model.fp16.safetensors:   0%|          | 0.00/167M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "53a731ca1f1a4262acf999500ea81150",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks\n",
    "from diffusers.configuration_utils import register_to_config\n",
    "import torch\n",
    "from typing import Any, Dict, Optional\n",
    "\n",
    "\n",
    "# Load Stable Diffusion v1.5 in half precision on the GPU.\n",
    "# The safety checker is disabled at load time (rather than by assigning\n",
    "# attributes afterwards) so it is not instantiated and moved to the GPU\n",
    "# only to be thrown away.\n",
    "pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(\n",
    "    \"stable-diffusion-v1-5/stable-diffusion-v1-5\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    safety_checker=None,\n",
    "    requires_safety_checker=False,\n",
    ").to(\"cuda\")\n",
    "\n",
    "\n",
    "class SDPromptScheduleCallback(PipelineCallback):\n",
    "    \"\"\"Switch the conditioning prompt of a Stable Diffusion run at a given step.\n",
    "\n",
    "    When the denoising loop reaches the cutoff step, `prompt` and\n",
    "    `negative_prompt` are encoded with the pipeline's own `encode_prompt`\n",
    "    and handed back as the new `prompt_embeds`.\n",
    "\n",
    "    Args:\n",
    "        prompt: Prompt to switch to at the cutoff step.\n",
    "        negative_prompt: Negative prompt encoded together with `prompt`.\n",
    "        num_images_per_prompt: Forwarded to `encode_prompt`. Presumably this\n",
    "            has to match the value used for the pipeline call so the batch\n",
    "            sizes agree -- verify against your call site.\n",
    "        cutoff_step_ratio: Fraction of `pipeline.num_timesteps` at which the\n",
    "            switch happens. Only used when `cutoff_step_index` is None.\n",
    "        cutoff_step_index: Absolute step index at which the switch happens.\n",
    "            Takes precedence over `cutoff_step_ratio` when it is not None.\n",
    "    \"\"\"\n",
    "\n",
    "    # Names of the pipeline tensors this callback receives and may replace.\n",
    "    tensor_inputs = [\"prompt_embeds\"]\n",
    "\n",
    "    @register_to_config\n",
    "    def __init__(\n",
    "        self,\n",
    "        prompt: str,\n",
    "        negative_prompt: Optional[str] = None,\n",
    "        num_images_per_prompt: int = 1,\n",
    "        cutoff_step_ratio=1.0,\n",
    "        cutoff_step_index=None,\n",
    "    ):\n",
    "        # `register_to_config` stores every argument on `self.config`; the\n",
    "        # base class only needs the two cutoff settings.\n",
    "        super().__init__(\n",
    "            cutoff_step_ratio=cutoff_step_ratio, cutoff_step_index=cutoff_step_index\n",
    "        )\n",
    "\n",
    "    def callback_fn(\n",
    "        self, pipeline, step_index, timestep, callback_kwargs\n",
    "    ) -> Dict[str, Any]:\n",
    "        \"\"\"Replace `prompt_embeds` with the scheduled prompt at the cutoff step.\n",
    "\n",
    "        Args:\n",
    "            pipeline: The running pipeline; used to encode the new prompt.\n",
    "            step_index: Index of the denoising step that just finished.\n",
    "            timestep: Scheduler timestep of that step (unused).\n",
    "            callback_kwargs: Tensors passed in by the pipeline.\n",
    "\n",
    "        Returns:\n",
    "            `callback_kwargs`, with `prompt_embeds` overwritten when\n",
    "            `step_index` equals the cutoff step and left untouched otherwise.\n",
    "        \"\"\"\n",
    "        config = self.config\n",
    "\n",
    "        # An explicit step index wins over the ratio-based cutoff.\n",
    "        if config.cutoff_step_index is not None:\n",
    "            cutoff_step = config.cutoff_step_index\n",
    "        else:\n",
    "            cutoff_step = int(pipeline.num_timesteps * config.cutoff_step_ratio)\n",
    "\n",
    "        if step_index != cutoff_step:\n",
    "            return callback_kwargs\n",
    "\n",
    "        use_cfg = pipeline.do_classifier_free_guidance\n",
    "        prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(\n",
    "            prompt=config.prompt,\n",
    "            negative_prompt=config.negative_prompt,\n",
    "            device=pipeline._execution_device,\n",
    "            num_images_per_prompt=config.num_images_per_prompt,\n",
    "            do_classifier_free_guidance=use_cfg,\n",
    "            # Encode with the same CLIP-skip setting as the running call so the\n",
    "            # new prompt is treated like the initial one (None when unset).\n",
    "            clip_skip=getattr(pipeline, \"clip_skip\", None),\n",
    "        )\n",
    "        if use_cfg:\n",
    "            # Same layout as the embeddings this replaces: negative (unconditional)\n",
    "            # embeddings first, then the positive ones, in a single batch.\n",
    "            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])\n",
    "        callback_kwargs[self.tensor_inputs[0]] = prompt_embeds\n",
    "        return callback_kwargs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bd75ef9ee32b454298a9bee7a1210627",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/50 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# The two prompts differ only in the subject's gender, so both are built\n",
    "# from one template to keep the rest of the text in sync.\n",
    "PROMPT_TEMPLATE = (\n",
    "    \"Official portrait of a smiling world war ii general, {gender}, cheerful, happy, \"\n",
    "    \"detailed face, 20th century, highly detailed, cinematic lighting, \"\n",
    "    \"digital art painting by Greg Rutkowski\"\n",
    ")\n",
    "NEGATIVE_PROMPT = \"Deformed, ugly, bad anatomy\"\n",
    "SEED = 0\n",
    "\n",
    "# Switch from the \"male\" prompt to the \"female\" prompt a quarter of the way\n",
    "# through the denoising steps.\n",
    "callback = MultiPipelineCallbacks(\n",
    "    [\n",
    "        SDPromptScheduleCallback(\n",
    "            prompt=PROMPT_TEMPLATE.format(gender=\"female\"),\n",
    "            negative_prompt=NEGATIVE_PROMPT,\n",
    "            cutoff_step_ratio=0.25,\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Seed the sampler so a fresh Restart & Run All reproduces the same image.\n",
    "generator = torch.Generator(device=pipeline.device).manual_seed(SEED)\n",
    "\n",
    "image = pipeline(\n",
    "    prompt=PROMPT_TEMPLATE.format(gender=\"male\"),\n",
    "    negative_prompt=NEGATIVE_PROMPT,\n",
    "    generator=generator,\n",
    "    callback_on_step_end=callback,\n",
    "    callback_on_step_end_tensor_inputs=callback.tensor_inputs,\n",
    ").images[0]\n",
    "\n",
    "torch.cuda.empty_cache()\n",
    "image.save('image2.png')"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
